// unet
#include "cunet.h"
#include "partialconv.h"

// nanovoid sim
#include "irradiation_model.h"

// data loading
#include "load_data.h"

// torch and other utils
#include <torch/torch.h>
// #include <torch/script.h>
// #include <boost/filesystem.hpp>

// standard library
#include <chrono>   // timing of forward/backward passes
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <memory>
#include <string>

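// namespace fs = boost::filesystem;
// NOTE(review): `idx`, `all_select` and `debug_on` are used throughout main(),
// but the local definitions below are disabled; presumably one of the project
// headers included above provides them — confirm.
// namespace idx = torch::indexing;
// static idx::Slice all_select(idx::None, idx::None, idx::None);
// static idx::Slice one_one_select(1, -1, idx::None);
// static bool debug_on = true;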


/// Stitch a learned field and a reference field together with a per-sample
/// region selector.
///
/// For every sample b the effective mask is
///     mask2[b] = mask1 * ul1[b] + (1 - mask1) * (1 - ul1[b]),
/// i.e. `mask1` itself when ul1[b] == 1 and its complement when ul1[b] == 0
/// (other values of ul1 blend the two linearly). The result takes `cv_ori`
/// where mask2 is 1 and `cv1` where mask2 is 0. In this file it is used to
/// overwrite part of the U-Net's eta with the ground-truth eta.
///
/// @param cv1    field used outside the selected region (e.g. U-Net output)
/// @param cv_ori field kept inside the selected region (e.g. ground truth)
/// @param mask1  spatial mask; its last two dims (H, W) define the grid size
/// @param ul1    per-sample selector; assumed shape (B,) — TODO confirm
/// @return cv_ori * mask2 + cv1 * (1 - mask2), where mask2 is (B, H, W) for a
///         1-D ul1 and the inputs broadcast against it
torch::Tensor stitch_by_mask(torch::Tensor cv1, torch::Tensor cv_ori, torch::Tensor mask1, torch::Tensor ul1) {
    // Lift ul1 to (..., 1, 1) and tile it over the mask's spatial grid. The grid
    // size comes from the mask instead of a hard-coded 128 x 128.
    const torch::Tensor ul =
        ul1.unsqueeze(-1).unsqueeze(-1).repeat({1, mask1.size(-2), mask1.size(-1)});
    const torch::Tensor mask2 = mask1 * ul + (1.0 - mask1) * (1.0 - ul);
    return cv_ori * mask2 + cv1 * (1.0 - mask2);
}

int main(int argc, char **argv) {

    int seed;
    char* run_idx;

    if (argc > 2) {
        run_idx = argv[1];
        seed = std::stoi(argv[2]);
    }
    else {
        run_idx = "";
        seed = 4321;
    }

    char* cv_path = "../../data/irradiation_v3/cv_all_data.pt";
    char* ci_path = "../../data/irradiation_v3/ci_all_data.pt";
    char* eta_path = "../../data/irradiation_v3/eta_all_data.pt";
    char* video_path = "../../data/irradiation_v3/video_all_data.pt";
    // char* filename_pkl = "container_all_data.pt";
    // char* video_pkl = "container_video_data.pt";

    // char* best_ts_model_out = "best_ts_model_out_v3.0_ef8"
    // char* best_unet_model_out = ""
    // char* best_unet_model_pretrained = ""

    char* best_ts_model_pretrained = NULL;

    Param param;

    param.lr = 1e-1;
    param.lr2 = 1e-3;
    param.lambda1 = 10.0;
    param.lambda2 = 10.0;
    param.skip_step = 30;
    param.batch_size = 32;
    param.epoch = 1;

    param.embedding_features = 8;

    double min_loss = 1000.0;

    torch::manual_seed(seed);
    torch::cuda::manual_seed(seed);
    torch::autograd::AnomalyMode::set_enabled(true); 

    // skip for testing
    // IrradiationVideoDataset dataset(cv_path, ci_path, eta_path, video_path, param.skip_step);

    if (debug_on) {
        std::cout << "finish data loading" << std::endl;
    }

    ParameterSet para(0.0);

    para.energy_v0 = 3.90086937;
    para.energy_i0 = 0.20275441;
    para.kBT0 = -2.56907487;
    para.kappa_v0 = -0.38098839;
    para.kappa_i0 = -0.78148192;
    para.kappa_eta0 = 0.34326860;
    para.r_bulk0 = 4.99900007;
    para.r_surf0 = 9.99899960;
    para.p_casc0 = 0.00900000;
    para.bias0 = 0.29899999;
    para.vg0 = 0.00900000;
    para.diff_v0 = 1.15292716;
    para.diff_i0 = -0.19592035;
    para.L0 = -0.52126276;

    IrradiationSingleTimeStep ts_model(param.dt, param.dx, param.dy, param.eps, param._N, para);

    torch::Device device(torch::kCPU);

    ts_model->to(device);

    if (debug_on) {
        std::cout << "finish ts model init" << std::endl;
    }

    if (debug_on) {
        for (const auto& pair : ts_model->named_parameters()) {
            std::cout << pair.key() << ": " << pair.value() << std::endl;
            std::cout << "is_leaf: " << pair.value().is_leaf() << std::endl;
        }
    }

    IrradiationVideoDataset dataset(cv_path, ci_path, eta_path, video_path, param.skip_step);

    // torch::nn::MSELossOptions mseopt(torch::kNone);

    // mseopt.reduction()

    torch::nn::MSELoss mse(torch::nn::MSELossOptions(torch::kSum));

    if (debug_on) {
        std::cout << "finish mseloss init" << std::endl;
    }

    torch::optim::Adam optimizer(ts_model->parameters(), torch::optim::AdamOptions(param.lr));

    if (debug_on) {
        std::cout << "finish optim1 init" << std::endl;
    }

    // CUNet2dWithEmbeddingGen video2pf(3, 3, 32, 5, 3, true, true, true, true, false, false, param.embedding_features, dataset.get_len() + param.skip_step, 8, false);
    // for testing
    CUNet2dWithEmbeddingGen video2pf(3, 3, 32, 5, 3, true, true, true, true, false, false, param.embedding_features, dataset.get_len() + param.skip_step, 8, false);


    // video2pf.eval();
    if (debug_on) {
        std::cout << "finish video2pf init" << std::endl;
    }

    torch::optim::Adam optimizer2(video2pf->parameters(), torch::optim::AdamOptions(param.lr2));

    if (debug_on) {
        std::cout << "finish optim2 init" << std::endl;
    }

    torch::Tensor mask = torch::ones({128, 128}, torch::dtype(torch::kFloat32).requires_grad(false));

    mask.index_put_({idx::Slice(64, 128, idx::None), all_select}, 0.0);

    if (debug_on) {
        std::cout << "start training..." << std::endl;
    }

    for (int i = 0; i < param.epoch; ++ i) {
        double loss = 0.0;
        int total_size = 0;
        printf("epoch:\t%d\n", i);
        // if (debug_on) {
        //     std::cout << "len of data: " << dataset.get_len() << std::endl;
        // }
        for (int index = param.start_skip + 1; index < (dataset.get_len() - param.start_skip - param.skip_step*2); ++ index) { // should be batch in loader, iterate all data point in train set
            
            ReturnItem rt = dataset.get_item(index);
            
            // ground truth at time 0
            torch::Tensor cv1 = rt.rd.cv;
            torch::Tensor ci1 = rt.rd.ci;
            torch::Tensor eta1 = rt.rd.eta;

            torch::Tensor frame1 = rt.rd.v;
            // torch::Tensor indicies1 = torch::from_blob(&rt.rd.index, {1}, torch::kInt64);
            int indicies1 = rt.rd.index;

            torch::Tensor ul1 = rt.rd.ul;

            if (debug_on) {
                std::cout << "success get data" << std::endl;
                // std::cout << "frame1 size: " << frame1.sizes() << std::endl;
            }

            torch::Tensor pf = video2pf->forward(frame1, indicies1);

            if (debug_on) {
                std::cout << "success forward in unet" << std::endl;
            }

            // learned cv ci eta from unet
            torch::Tensor cv = pf.index({all_select, 0, all_select, all_select});
            torch::Tensor ci = pf.index({all_select, 1, all_select, all_select});
            torch::Tensor eta = pf.index({all_select, 2, all_select, all_select});
            // std::cout << "cv size: " << cv.sizes() << ", ci size: " << ci.sizes() << ", eta size: " << eta.sizes() << std::endl;

            eta = stitch_by_mask(eta, eta1, mask, ul1);

            if (debug_on) {
                std::cout << "success get learned cv ci eta from unet" << std::endl;
            }

            torch::Tensor concate_vals;
            auto start = std::chrono::high_resolution_clock::now();
            for (int j = 0; j < param.skip_step; ++ j) {
                
                concate_vals = ts_model->forward(cv, ci, eta);
                
                cv = concate_vals.index({0});
                ci = concate_vals.index({1});
                eta = concate_vals.index({2});
                cv.unsqueeze_(0);
                ci.unsqueeze_(0);
                eta.unsqueeze_(0);
                // std::cout << "concate_vals size: " << concate_vals.sizes() << std::endl;
                // std::cout << "cv size: " << cv.sizes() << ", ci size: " << ci.sizes() << ", eta size: " << eta.sizes() << std::endl;
            }
            auto stop = std::chrono::high_resolution_clock::now();
            auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start);
            std::cout << "time of ts model forward: " << duration.count() << "ms in " << param.skip_step << "steps" << std::endl;

            if (debug_on) {
                std::cout << "success forward in ts model" << std::endl;
            }

            torch::Tensor cv_ref = rt.rl.cv_ref;
            torch::Tensor ci_ref = rt.rl.ci_ref;
            torch::Tensor eta_ref = rt.rl.eta_ref;

            torch::Tensor frame2 = rt.rl.v_ref;

            // torch::Tensor indicies2 = torch::from_blob(&rt.rl.index_ref, {1}, torch::kInt);
            int indicies2 = rt.rl.index_ref;

            torch::Tensor ul2 = rt.rl.ul_ref;

            pf = video2pf(frame2, indicies2);

            if (debug_on) {
                std::cout << "success forward in unet, ref" << std::endl;
            }

            torch::Tensor cv_frame = pf.index({all_select, 0, all_select, all_select});
            torch::Tensor ci_frame = pf.index({all_select, 1, all_select, all_select});
            torch::Tensor eta_frame = pf.index({all_select, 2, all_select, all_select});

            if (debug_on) {
                std::cout << "success get learned cv ci eta from unet, ref" << std::endl;
            }

            ul2.unsqueeze_(-1).unsqueeze_(-1);
            ul2 = ul2.repeat({1, 128, 128});
            torch::Tensor mask2 = mask * ul2 + (1.0 - mask) * (1.0 - ul2);

            torch::Tensor cv_new = concate_vals.index({0});
            torch::Tensor ci_new = concate_vals.index({1});
            torch::Tensor eta_new = concate_vals.index({2});

            cv_new.unsqueeze_(0);
            ci_new.unsqueeze_(0);
            eta_new.unsqueeze_(0);

            // std::cout << "cv frame size: " << cv_frame.sizes() << ", cv new size: " << cv_new.sizes() << std::endl;
            // std::cout << "ci frame size: " << ci_frame.sizes() << ", ci new size: " << ci_new.sizes() << std::endl;
            // std::cout << "eta frame size: " << eta_frame.sizes() << ", eta new size: " << eta_new.sizes() << ", eta ref size: " << eta_ref.sizes() << std::endl;

            torch::Tensor cv_batch_loss = param.lambda2 * mse->forward(cv_frame, cv_new);
            torch::Tensor ci_batch_loss = param.lambda2 * mse->forward(ci_frame, ci_new);
            torch::Tensor eta_batch_loss = mse->forward(mask2 * eta_ref, mask2 * eta_new) + \
                         param.lambda1 * mse->forward(mask2 * eta_ref, mask2 * eta_frame) + \
                         param.lambda2 * mse->forward(eta_frame, eta_new);
            
            torch::Tensor batch_loss = cv_batch_loss + ci_batch_loss + eta_batch_loss;

            if (debug_on) {
                std::cout << "success get loss" << std::endl;
            }

            optimizer.zero_grad();
            optimizer2.zero_grad();

            if (debug_on) {
                std::cout << "success opt1 opt2 zero grad" << std::endl;
            }
            auto start_back = std::chrono::high_resolution_clock::now();
            batch_loss.backward();
            auto stop_back = std::chrono::high_resolution_clock::now();
            auto duration_back = std::chrono::duration_cast<std::chrono::milliseconds>(stop_back - start_back);
            std::cout << "time of ts model backward: " << duration_back.count() << "ms in " << param.skip_step << "steps" << std::endl;

            if (debug_on) {
                std::cout << "success batch loss backward" << std::endl;
            }

            optimizer.step();
            optimizer2.step();

            if (debug_on) {
                std::cout << "success opt1 opt2 step()" << std::endl;
            }

            int this_size = cv.size(0);
            loss += (batch_loss.item<double>());
            if (true) {
                std::cout << "batch loss: " << (batch_loss.item<valueType>()) << std::endl;
                std::cout << "cv loss: " << cv_batch_loss.item<valueType>() << ", ci loss: " \ 
                            << ci_batch_loss.item<valueType>() << ", eta loss: " << eta_batch_loss.item<valueType>() << std::endl;
                {
                    for (const auto& pair : ts_model->named_parameters()) {
                        std::cout << pair.key() << "'s grad: " << pair.value().grad() << std::endl;
                        std::cout << pair.key() << "'s value: " << pair.value() << std::endl;
                    }
                }
            }
            total_size += this_size;

        }

        loss /= total_size;
        printf("loss:\t%.8f\n", loss);

        if (loss < min_loss) {
            min_loss = loss;
            // save ts_model
            // save video2pf 
        }
        else {
            printf("Above min_loss\n");
        }
    }

    // testing
    // for (int j = 0; j < param.skip_step; ++ j) {
        
    //     concate_vals = ts_model->forward(cv, ci, eta);
        
    //     cv = concate_vals.index({0});
    //     ci = concate_vals.index({1});
    //     eta = concate_vals.index({2});
    //     cv.unsqueeze_(0);
    //     ci.unsqueeze_(0);
    //     eta.unsqueeze_(0);
    //     // std::cout << "concate_vals size: " << concate_vals.sizes() << std::endl;
    //     // std::cout << "cv size: " << cv.sizes() << ", ci size: " << ci.sizes() << ", eta size: " << eta.sizes() << std::endl;
    // }

    return 0;
}

